#! /usr/bin/env python3
###########################################################################
# Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved.
# SPDX-License-Identifier: MIT
###########################################################################
import argparse
import json
from math import floor
import os.path
import sys
import xml.etree.ElementTree as etree

# Column titles.  These strings are used both as keys of the per-row
# dictionaries and as the header text of the printed tables.
_name = 'Name'
_bram = 'BRAM_18k'
_interval = 'Interval'
_clock_period = 'Clock Period'
_clock_rate = 'Clock Rate (MHz)'
_tns = 'TNS'
_wns = 'WNS'
_bad_perf = 'Bad Perf'
_errors = 'Errors'

# Resource element tags which are reported under a different column
# title.  Tags not listed here are reported under their own name.
_resource_titles = {'BRAM': _bram}

# Pairs of (timing report element name, column title).
_timing_titles = [
    ('CP_FINAL', _clock_period),
    ('TNS_FINAL', _tns),
    ('WNS_FINAL', _wns),
]

# The file which identifies (and describes) a nanotube HLS build
# directory.
_nanotube_hls_build_json = 'nanotube_hls_build.json'

# BRAM geometry used when estimating how many BRAMs a FIFO needs: the
# width is compared with the element size in bits and the depth with
# the number of elements.
bram_width = 36
bram_depth = 512

###########################################################################

def exception_str(e):
    """Return a short, human readable description of exception *e*.

    For an OSError the operating system's message (``strerror``) is
    preferred because it omits the errno and file name.  ``strerror``
    is None when the OSError was constructed without an errno (for
    example ``OSError("message")``), so in that case, and for every
    other exception type, ``str(e)`` is returned instead.
    """
    if isinstance(e, OSError) and e.strerror:
        return e.strerror
    return str(e)

###########################################################################

class column:
    """Base class describing one column of the module summary table.

    The summary code aggregates a column over all rows by calling
    init_val() once, accumulate() for every row which has a value and
    finally format() to obtain the text to print.  Sub-classes must
    provide all three methods.

    name   -- the column title, also used as the key into row data.
    kwargs -- any extra keyword options, kept for use by sub-classes.
    """

    def __init__(self, name, **kwargs):
        self.name = name
        self.kwargs = kwargs

    def init_val(self):
        """Return the accumulator value to use before any row is seen."""
        raise NotImplementedError(
            "%s does not implement init_val()" % (type(self).__name__,))

    def accumulate(self, acc, val_str):
        """Return the accumulator *acc* updated with the value *val_str*."""
        raise NotImplementedError(
            "%s does not implement accumulate()" % (type(self).__name__,))

    def format(self, acc):
        """Return the accumulator *acc* as the string to display."""
        raise NotImplementedError(
            "%s does not implement format()" % (type(self).__name__,))

class name_column(column):
    """The row label column: nothing is aggregated for it."""

    def init_val(self):
        """Return the initial accumulator, which is always None."""
        return None

    def accumulate(self, acc, val_str):
        """Ignore the row value; the accumulator never leaves None."""
        return self.init_val()

    def format(self, acc):
        """Return the fixed label used for the aggregated row."""
        label = 'Top'
        return label

class int_sum_column(column):
    """A column aggregated as the integer sum of the row values."""

    def init_val(self):
        """Return zero, the starting point of the sum."""
        return 0

    def accumulate(self, acc, val_str):
        """Return the running total *acc* plus the integer in *val_str*."""
        increment = int(val_str)
        return acc + increment

    def format(self, acc):
        """Return the total converted with str()."""
        text = str(acc)
        return text

class int_max_column(column):
    """A column aggregated as the largest integer row value."""

    def init_val(self):
        """Return zero, so the result is never below zero."""
        return 0

    def accumulate(self, acc, val_str):
        """Return whichever is larger: *acc* or the integer in *val_str*."""
        candidate = int(val_str)
        if candidate > acc:
            return candidate
        return acc

    def format(self, acc):
        """Return the maximum converted with str()."""
        text = str(acc)
        return text

class float_sum_column(column):
    """A column aggregated as the floating point sum of the row values."""

    def init_val(self):
        """Return zero, the starting point of the sum."""
        return 0

    def accumulate(self, acc, val_str):
        """Return the running total *acc* plus the float in *val_str*."""
        increment = float(val_str)
        return acc + increment

    def format(self, acc):
        """Return the total with three decimals, at least five wide."""
        return format(acc, "5.3f")

class float_min_column(column):
    """A column aggregated as the smallest floating point row value.

    The starting value is taken from the optional ``init`` keyword
    given to the constructor.  Without it the accumulator starts as
    None, which means "no value seen yet".
    """

    def init_val(self):
        """Return the ``init`` constructor option, or None if not given."""
        return self.kwargs.get('init')

    def accumulate(self, acc, val_str):
        """Return the smaller of *acc* and the float in *val_str*.

        An accumulator of None is replaced by the first value seen.
        """
        value = float(val_str)
        if acc is None:
            return value
        return min(acc, value)

    def format(self, acc):
        """Return the minimum with three decimals, or "" if there is none."""
        if acc is None:
            return ""
        return "%5.3f" % (acc,)

class float_max_column(column):
    """A column aggregated as the largest floating point row value."""

    def init_val(self):
        """Return zero, so the result is never below zero."""
        return 0

    def accumulate(self, acc, val_str):
        """Return whichever is larger: *acc* or the float in *val_str*."""
        candidate = float(val_str)
        if candidate > acc:
            return candidate
        return acc

    def format(self, acc):
        """Return the maximum with three decimals, at least five wide."""
        return format(acc, "5.3f")

# The columns of the module table and how each one is aggregated into
# the top-level row.  The order of this list is the order in which the
# columns are added to the module table.  Every title is given through
# its module-level constant where one exists, so that the key written
# by the report parsers and the key aggregated here cannot drift apart.
_columns = [
    name_column(_name),
    int_sum_column(_bram),
    int_sum_column('CLB'),
    int_sum_column('DSP'),
    int_sum_column('FF'),
    int_sum_column('LATCH'),
    int_sum_column('LUT'),
    int_sum_column('SRL'),
    int_sum_column('URAM'),
    int_max_column(_interval),
    float_max_column(_clock_period),
    float_min_column(_clock_rate),
    float_sum_column(_tns),
    float_min_column(_wns),
    int_sum_column(_bad_perf),
    int_sum_column(_errors),
]
 
###########################################################################

class table:
    """Collect rows and write them out as an ASCII-art table.

    Rows are stored as (style, data) pairs where style 0 is a normal
    row of cells and style 1 is a horizontal ruler.  Columns appear in
    the order in which they were first added, either explicitly with
    add_col() or implicitly by add_data_row().
    """

    def __init__(self):
        self.__rows = []
        self.__header_data = {}
        self.__col_names = []
        self.__col_idx = {}

    def add_header(self):
        """Add the header row, showing the title of every column."""
        # Note, this carries a live reference to the header data so it
        # is always up-to-date, even for columns added afterwards.
        self.__rows.append((0, self.__header_data))

    def add_col(self, key):
        """Add the column *key* unless the table already has it."""
        if key in self.__col_idx:
            return
        self.__col_idx[key] = len(self.__col_names)
        self.__col_names.append(key)
        self.__header_data[key] = key

    def add_data_row(self, data):
        """Add a row of cells.

        *data* is anything accepted by dict(): a mapping or an
        iterable of (column, value) pairs.  A private copy is stored.
        """
        data = dict(data)

        # Add any columns which have not yet been added.
        for key in data:
            self.add_col(key)

        # Add the row.
        self.__rows.append((0, data))

    def add_ruler(self):
        """Add a horizontal ruler line."""
        self.__rows.append((1, {}))

    def write(self, fh):
        """Write the formatted table to the text stream *fh*.

        Every column is as wide as its widest cell (header included)
        and cells are right-aligned.  Missing or None cells are left
        blank in normal rows and filled with dashes in rulers.
        """
        # Calculate the column widths.
        col_widths = {}
        for _style, data in self.__rows:
            for key, val in data.items():
                width = len(str(val))
                col_widths[key] = max(col_widths.get(key, 0), width)

        # Per-style formatting, indexed by the row style.
        fmt = ("| %s |", "+-%s-+")
        sep = (" | ", "-+-")
        fill = ("", "-")

        # Write the lines.
        for style, data in self.__rows:
            fields = []
            for col_name in self.__col_names:
                # A column added with add_col() which has no cell in
                # any row (and no header row) has no measured width.
                width = col_widths.get(col_name, 0)
                val = data.get(col_name)
                if val is None:
                    val = fill[style] * width
                fields.append("%*s" % (width, val))
            print(fmt[style] % (sep[style].join(fields),), file=fh)

###########################################################################

class app:
    """Command line tool which summarises HLS build results.

    The inputs named on the command line are XML implementation
    reports, nanotube JSON files, HLS project directories or nanotube
    HLS build directories.  The collected FIFO and module data is
    written to standard output as two tables.
    """

    def __init__(self, argv):
        """Remember the command line *argv* (program name first)."""
        self.__argv = argv
        self.__modules = []
        self.__fifos = []

    def _fatal(self, message):
        """Write 'program: message' to standard error and exit with 1."""
        sys.stderr.write("%s: %s\n" % (self.__argv[0], message))
        sys.exit(1)

    def parse_args(self):
        """Parse the command line and store the result for later use."""
        p = argparse.ArgumentParser()

        p.add_argument('--short', '-s', action="store_true",
                       help="Use the short output format.")
        p.add_argument('inputs', nargs='+',
                       help='The input files to parse.')

        self.__args = p.parse_args(self.__argv[1:])

    def find_one_node(self, et, path):
        """Return the only node of *et* matching *path*.

        The program exits with status 1 if there is no match or more
        than one match.
        """
        nodes = et.findall(path)
        if len(nodes) == 1:
            return nodes[0]
        if not nodes:
            self._fatal("Could not find node '%s'." % (path,))
        self._fatal("Too many nodes '%s'." % (path,))

    def process_impl_xml(self, line, et):
        """Copy the name, resources and timing of report *et* into *line*.

        The row is flagged as having bad performance if the total
        negative slack is below zero.  The clock rate in MHz is
        derived from the clock period.
        """
        top_path = './RtlModules/RtlModule[@IS_TOP="1"]'
        top = self.find_one_node(et, top_path)
        top_name = top.get('MODULENAME')

        resources_path = "./AreaReport/Resources"
        resources_node = self.find_one_node(et, resources_path)

        line[_name] = top_name
        for node in resources_node:
            title = _resource_titles.get(node.tag, node.tag)
            line[title] = node.text

        timing_path = "./TimingReport"
        timing_node = self.find_one_node(et, timing_path)
        for t_name, t_title in _timing_titles:
            node = self.find_one_node(timing_node, './' + t_name)
            line[t_title] = node.text
            if t_title == _tns and float(node.text) < 0:
                line[_bad_perf] = 1
            if t_title == _clock_period:
                line[_clock_rate] = 1000 / float(node.text)

    def read_xml_file(self, name):
        """Read the implementation report *name* as one module row."""
        line = {}
        et = etree.parse(name)
        self.process_impl_xml(line, et)
        self.__modules.append(line)

    def process_syn_xml(self, line, et):
        """Copy the maximum interval of synthesis report *et* into *line*.

        Any interval other than 1 flags the row as bad performance.
        """
        ii_path = './PerformanceEstimates/SummaryOfOverallLatency/Interval-max'
        ii_node = self.find_one_node(et, ii_path)
        line[_interval] = ii_node.text
        if int(ii_node.text) != 1:
            line[_bad_perf] = 1

    def _process_report(self, line, path, handler):
        """Parse the XML file *path* and pass it to handler(line, tree).

        This is best effort: any failure is reported on standard error
        and recorded in the error column of *line*, so that one broken
        report does not prevent the rest of the summary.
        """
        try:
            handler(line, etree.parse(path))
        except Exception as e:
            sys.stderr.write("Error processing %r: %s\n" %
                             (path, exception_str(e)))
            line[_errors] = 1

    def read_hls_proj(self, dirname):
        """Read the reports of the HLS project *dirname* as one module row."""
        line = {}

        # The directory name is used unless the implementation report
        # supplies the name of the top module.
        line[_name] = os.path.basename(dirname)

        impl_report_rel = "solution1/impl/report/verilog/export_syn.xml"
        self._process_report(line, os.path.join(dirname, impl_report_rel),
                             self.process_impl_xml)

        syn_report_rel = "solution1/syn/report/csynth.xml"
        self._process_report(line, os.path.join(dirname, syn_report_rel),
                             self.process_syn_xml)

        self.__modules.append(line)

    def process_nanotube_json(self, top):
        """Record one FIFO row for every entry of top['channels'].

        The number of BRAMs is estimated as the number of BRAM widths
        needed for one element times the number of BRAM depths needed
        for all elements, both rounded up.
        """
        for e in top['channels']:
            name = "fifo_" + str(e['channel_id'])
            elem_size = int(e['elem_size'])
            elem_bits = elem_size * 8
            num_elem = int(e['num_elem'])
            # Round up using integer arithmetic, which stays exact for
            # any size, unlike a division which goes through a float.
            bram_cols = (elem_bits + bram_width - 1) // bram_width
            bram_rows = (num_elem + bram_depth - 1) // bram_depth
            num_bram = bram_rows * bram_cols
            self.__fifos.append((
                (_name, name),
                ('Elem size', elem_size),
                ('Num elem', num_elem),
                ('BRAM rows', bram_rows),
                ('BRAM cols', bram_cols),
                ('Num BRAMs', num_bram),
            ))

    def read_hls_build(self, path):
        """Read the nanotube HLS build directory *path*.

        The build description provides the FIFOs and the list of
        stages.  Each stage is read as an HLS project from the
        sub-directory 'stage_<thread_id>'.
        """
        with open(os.path.join(path, _nanotube_hls_build_json)) as fh:
            top = json.load(fh)
        self.process_nanotube_json(top)
        for e in top['stages']:
            dir_name = 'stage_' + str(e['thread_id'])
            self.read_hls_proj(os.path.join(path, dir_name))

    def read_json_file(self, filename):
        """Read the FIFOs described by the nanotube JSON file *filename*."""
        with open(filename) as fh:
            t = json.load(fh)
        self.process_nanotube_json(t)

    def read_file(self, path):
        """Read the input file *path* according to its extension.

        The program exits with status 1 for an unknown extension.
        """
        if path.endswith(".json"):
            self.read_json_file(path)
        elif path.endswith(".xml"):
            self.read_xml_file(path)
        else:
            self._fatal("Unknown file extension in '%s'." % (path,))

    def read_dir(self, path):
        """Read the input directory *path* according to its contents.

        A directory holding 'hls.app' is an HLS project and one
        holding the nanotube build description is an HLS build.  The
        program exits with status 1 for anything else.
        """
        if os.path.exists(os.path.join(path, "hls.app")):
            self.read_hls_proj(path)
        elif os.path.exists(os.path.join(path, _nanotube_hls_build_json)):
            self.read_hls_build(path)
        else:
            self._fatal("Unrecognised directory '%s'." % (path,))

    def read_inputs(self):
        """Read every input named on the command line, in order."""
        for name in self.__args.inputs:
            if not os.path.exists(name):
                self._fatal("Missing file or directory '%s'." % (name,))

            if os.path.isfile(name):
                self.read_file(name)

            elif os.path.isdir(name):
                self.read_dir(name)

            else:
                self._fatal("Unsupported inode type '%s'." % (name,))

    def write_summary(self):
        """Write the FIFO table and the module table to standard output.

        With the short option only the header and the aggregated row
        of each table are written.
        """
        print("FIFOs")
        print("-----")
        print("")

        # Calculate the FIFO totals.
        totals = {_name: "Totals"}
        for f in self.__fifos:
            for key, val in f:
                if key != _name:
                    totals[key] = totals.get(key, 0) + int(val)
        fifo_brams = totals.get('Num BRAMs', 0)

        # Write the FIFO table.
        t = table()
        t.add_ruler()
        t.add_header()
        if not self.__args.short:
            t.add_ruler()
            for f in self.__fifos:
                t.add_data_row(f)
        t.add_ruler()
        t.add_data_row(totals)
        t.add_ruler()
        t.write(sys.stdout)
        print("")

        print("Modules")
        print("-------")
        print("")

        # The BRAMs used by the FIFOs are reported as an extra module.
        modules = list(self.__modules)
        modules.append({
            _name: "FIFOs",
            _bram: fifo_brams,
        })

        # Calculate the top-level module values.  Only columns with a
        # value in at least one module are kept.
        top = [[c.name, c.init_val()] for c in _columns]
        keys = set()
        for m in modules:
            for idx, c in enumerate(_columns):
                val = m.get(c.name)
                if val is not None:
                    top[idx][1] = c.accumulate(top[idx][1], val)
                    keys.add(c.name)

        for idx, c in enumerate(_columns):
            top[idx][1] = c.format(top[idx][1])
        top = [e for e in top if e[0] in keys]

        # Create the module table.  The known columns are added first
        # so that they appear in their defined order.
        t = table()
        t.add_ruler()
        t.add_header()
        for c in _columns:
            if c.name in keys:
                t.add_col(c.name)

        if not self.__args.short:
            t.add_ruler()
            for m in modules:
                # Format a copy so the collected data is not modified.
                row = dict(m)
                for c in _columns:
                    val = row.get(c.name)
                    if val is not None and not isinstance(val, str):
                        row[c.name] = c.format(val)
                t.add_data_row(row)
        t.add_ruler()
        t.add_data_row(top)
        t.add_ruler()

        # Write the module table.
        t.write(sys.stdout)
        print("")

    def run(self):
        """Parse the arguments, read the inputs and write the summary."""
        self.parse_args()
        self.read_inputs()
        # Emit any error messages before the summary is written.
        sys.stderr.flush()
        self.write_summary()

# Only run the tool when this file is executed as a script, so that
# importing it as a module does not parse sys.argv or produce output.
if __name__ == "__main__":
    app(sys.argv).run()

###########################################################################
